#%%
import os
from tqdm import tqdm
import asyncio
import pickle
from pathlib import Path
from pathlib import PosixPath
import torch
import numpy as np
import pickle
from datasets import Dataset, DatasetDict
import json
# Which model head this script drives. Only "token_classification" is
# implemented below; "sequence_classification" raises NotImplementedError.
MODEL_MODE = "token_classification"  # "sequence_classification"

# Seed both RNGs so repeated runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)

# NOTE(review): `inference_code_path` is not defined anywhere above this line
# in this file; presumably it is set in an earlier notebook cell and points at
# the directory of inference outputs -- confirm before running as a script.
#
# os.listdir() returns entries in arbitrary order, so sort them: the dataset
# loop resumes by index and needs a stable ordering. Path.stem strips only the
# final extension, so names that themselves contain dots are kept intact.
sample_names = sorted(
    Path(file_name).stem for file_name in os.listdir(inference_code_path)
)  # [:20]

# %%

def load_sample(sample_path):
    """Read the code text stored in a sample file.

    Args:
        sample_path: Location of the sample, as a ``str`` or ``pathlib.Path``.
            A ``.html`` file is returned verbatim; a ``.json`` file must hold
            an object whose ``"code"`` entry is returned.

    Returns:
        The code read from the file.

    Raises:
        ValueError: If the suffix is neither ``.html`` nor ``.json``.
    """
    # Normalise so plain string paths work as well as Path objects.
    sample_path = Path(sample_path)
    suffix = sample_path.suffix
    if suffix == ".html":
        # Explicit encoding: do not depend on the platform's locale default.
        return sample_path.read_text(encoding="utf-8")
    if suffix == ".json":
        with open(sample_path, "r", encoding="utf-8") as f:
            return json.load(f)["code"]
    # ValueError is still an Exception, so existing broad handlers keep working.
    raise ValueError(f"Unknown file format: {sample_path}")

def generate_sample(ground_code, predicted_code):
    """Combine a prediction and its reference into a single input string.

    The predicted code comes first, then a newline, then the reference code
    introduced by the literal marker ``Ground: ``.
    """
    return "{}\nGround: {}".format(predicted_code, ground_code)

dataset_list = []
# %%
# NOTE(review): `dataset_json_base` and `INFERENCE_BASE` are not defined in
# this file; presumably they are pathlib.Path objects set in another notebook
# cell -- confirm before running as a script.
for idx, sample_name in tqdm(enumerate(sample_names), total=len(sample_names)):
    # dataset_list may already hold entries from an interrupted run of this
    # cell; skip the samples that were processed before.
    if idx < len(dataset_list):
        continue
    original_code = load_sample(dataset_json_base / f"{sample_name}.json")
    # Use a dedicated name here: rebinding `inference_code_path` would replace
    # the directory that the sample listing at the top of the file reads from.
    inference_sample_path = INFERENCE_BASE / "codes" / f"{sample_name}.html"
    inference_code = load_sample(inference_sample_path)
    dataset_list.append(
        {
            "code": generate_sample(original_code, inference_code),
            "id": sample_name,
        }
    )

# %%

dataset = Dataset.from_list(dataset_list)

# %%

MAX_LENGTH = 1000


def tokenize_text(examples, tokenizer, max_length=None):
    """Tokenize a batch of samples and attach the model inputs to it.

    Args:
        examples: Batch mapping with a ``"code"`` entry holding the texts to
            tokenize and, optionally, a ``"label"`` entry.
        tokenizer: Callable tokenizer whose result exposes ``input_ids`` and
            ``attention_mask`` attributes.
        max_length: Length every sequence is truncated/padded to. Defaults to
            the module-level ``MAX_LENGTH`` (looked up at call time).

    Returns:
        The same ``examples`` mapping, updated in place with ``"input_ids"``
        and ``"attention_mask"``; when a ``"label"`` entry is present and
        ``MODEL_MODE`` is ``"token_classification"``, also with ``"labels"``.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    outputs = tokenizer(
        examples["code"],
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors="pt",
    )

    examples["input_ids"] = outputs.input_ids
    examples["attention_mask"] = outputs.attention_mask
    if "label" in examples and MODEL_MODE == "token_classification":
        # Give every token position the label of its sample: repeating the
        # label vector once per position and transposing puts samples first.
        seq_len = len(outputs.input_ids[0])
        examples["labels"] = torch.tensor(examples["label"]).repeat(seq_len, 1).T
    return examples

# %%

from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

# gpt 2 model
model_checkpoint = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Reuse the end-of-sequence token for padding; only if the tokenizer has no
# EOS token either, register a dedicated "[PAD]" token instead.
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Every column except "id" is meant to be dropped after tokenization.
# list.remove() mutates in place and returns None, so building the list with
# `list(...).remove("id")` always produced None (i.e. "remove nothing").
# NOTE(review): Dataset.map appears to keep any column the mapped function
# itself returns, so "code" may still survive -- confirm against the docs.
columns_to_remove = [column for column in dataset.column_names if column != "id"]

tokenized_dataset = dataset.map(
    lambda examples: tokenize_text(examples, tokenizer),
    batched=True,
    remove_columns=columns_to_remove,
)
tokenized_dataset.set_format("torch")
# %%


from transformers import AutoModelForTokenClassification
if MODEL_MODE == "sequence_classification":
    raise NotImplementedError("sequence_classification inference is not implemented")
else:
    # NOTE(review): the checkpoint argument is an empty string in this file;
    # the real fine-tuned checkpoint path/name has to be filled in before this
    # cell can load anything.
    model = AutoModelForTokenClassification.from_pretrained("", num_labels=4)

# %%

# Inference only: switch off dropout and similar training-time behaviour.
model.eval()
# Prefer the GPU but fall back to CPU so the script also runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def collate_fn(examples):
    """Stack per-example tensors into one batch dictionary.

    Prints a marker line and the raw examples (debug output), then returns a
    dict whose ``"input_ids"`` and ``"attention_mask"`` entries are the
    corresponding tensors of all examples stacked along a new first axis.
    """
    print("collate_fn")
    print(examples)
    return {
        field: torch.stack([item[field] for item in examples])
        for field in ("input_ids", "attention_mask")
    }

dataset_loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=8, shuffle=False)  # , collate_fn=collate_fn)

# Send inputs to whatever device the model currently lives on instead of
# assuming CUDA.
model_device = next(model.parameters()).device

# Softmax over the last axis, i.e. across the class scores of each token.
# The logits are indexed (sample, token, class) -- argmax(-1) followed by
# mode(dim=1) below relies on that layout -- so dim=1 would have normalised
# across token positions rather than across classes.
softmax = torch.nn.Softmax(dim=-1)
logits_list = []
predictions_list = []
ids_list = []
for batch in tqdm(dataset_loader):

    input_ids = batch["input_ids"].to(model_device)
    attention_mask = batch["attention_mask"].to(model_device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # Per-token class prediction.
        predictions = logits.argmax(-1)
        # Most common predicted class over the tokens of each sample.
        # NOTE(review): padded positions take part in this vote as well;
        # confirm that matches how the model was trained.
        prediction = predictions.mode(dim=1).values
        # Despite the list name, the stored values are softmax probabilities.
        logits_list.append(softmax(logits).cpu().numpy())
        predictions_list.append(prediction.cpu().numpy())
        ids_list.append(batch["id"])

# %%


# Collect the results as {sample id: {"logits": ..., "prediction": ...}}.
# The "logits" entry holds whatever was appended to logits_list for that sample
# and "prediction" the matching entry of predictions_list.
predictins_dict = {}
for batch_ids, batch_scores, batch_predictions in zip(ids_list, logits_list, predictions_list):
    # `sample_id` rather than `id`, so the builtin id() is not shadowed at
    # module scope for the rest of the session.
    for sample_id, scores, predicted in zip(batch_ids, batch_scores, batch_predictions):
        predictins_dict[sample_id] = {"logits": scores, "prediction": predicted}


        

# %%

# Persist the per-sample results to disk.
Path("critic_model_gpt2_token_pred_ground_predictions.pkl").write_bytes(
    pickle.dumps(predictins_dict)
)

# %%

# Read the file back to confirm the round trip worked.
predictins_dict_loaded = pickle.loads(
    Path("critic_model_gpt2_token_pred_ground_predictions.pkl").read_bytes()
)

# Compare the stored array shapes entry by entry; report the first mismatch
# and stop.
for key, loaded_entry in predictins_dict_loaded.items():
    if not loaded_entry["logits"].shape == predictins_dict[key]["logits"].shape:
        print("Error")
        break
# %%